package org.neuroph.nnet.learning;

import com.google.firebase.remoteconfig.FirebaseRemoteConfig;
import java.util.Iterator;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.learning.SupervisedLearning;
import org.neuroph.util.NeuralNetworkCODEC;

/* loaded from: classes2.dex */
/**
 * Supervised learning rule that trains a {@link NeuralNetwork} by simulated annealing.
 *
 * <p>Each learning epoch runs {@code cycles} iterations. In every iteration all weights are
 * perturbed by a random amount scaled by the current temperature. The perturbed weights are kept
 * only if they lower the error on the training set; otherwise the best weights found so far are
 * restored. Within an epoch the temperature decays geometrically from {@code startTemperature}
 * towards {@code stopTemperature}.
 */
public class SimulatedAnnealingLearning extends SupervisedLearning {
    private static final long serialVersionUID = 1;

    /** Best weight vector found so far, laid out as by {@link NeuralNetworkCODEC}. */
    private double[] bestWeights;
    /** Number of annealing iterations performed per learning epoch. */
    private int cycles;
    /** Network whose weights are annealed and evaluated. */
    protected NeuralNetwork network;
    /** Temperature at the beginning of every epoch. */
    private double startTemperature;
    /** Temperature the geometric cooling schedule aims for at the end of an epoch. */
    private double stopTemperature;
    /** Current temperature; reset to {@code startTemperature} at the start of every epoch. */
    protected double temperature;
    /** Candidate weight vector, laid out as by {@link NeuralNetworkCODEC}. */
    private double[] weights;

    /**
     * Creates a learning rule with start temperature 10, stop temperature 2 and 1000 cycles per
     * epoch.
     *
     * @param neuralNetwork network to train
     */
    public SimulatedAnnealingLearning(NeuralNetwork neuralNetwork) {
        this(neuralNetwork, 10.0d, 2.0d, 1000);
    }

    /**
     * Creates a learning rule with an explicit cooling schedule.
     *
     * @param neuralNetwork network to train; its current weights become the initial best solution
     * @param startTemperature temperature at the start of every epoch
     * @param stopTemperature temperature the schedule cools towards within an epoch
     * @param cycles number of annealing iterations per epoch
     */
    public SimulatedAnnealingLearning(NeuralNetwork neuralNetwork, double startTemperature,
                                      double stopTemperature, int cycles) {
        this.network = neuralNetwork;
        this.temperature = startTemperature;
        this.startTemperature = startTemperature;
        this.stopTemperature = stopTemperature;
        this.cycles = cycles;
        int size = NeuralNetworkCODEC.determineArraySize(neuralNetwork);
        this.weights = new double[size];
        this.bestWeights = new double[size];
        NeuralNetworkCODEC.network2array(neuralNetwork, this.weights);
        NeuralNetworkCODEC.network2array(neuralNetwork, this.bestWeights);
    }

    /**
     * Evaluates the annealed network on the data set and returns the summed pattern error.
     *
     * <p>Stops iterating as soon as learning is stopped, in which case the returned value only
     * covers the rows evaluated so far.
     *
     * @param dataSet rows to evaluate
     * @return sum over rows of (sum of squared output errors) / (2 * number of outputs)
     */
    private double determineError(DataSet dataSet) {
        double error = 0.0d;
        Iterator<DataSetRow> rows = dataSet.iterator();
        while (rows.hasNext() && !isStopped()) {
            DataSetRow row = rows.next();
            // Evaluate the same network that randomize() and array2network() modify.
            this.network.setInput(row.getInput());
            this.network.calculate();
            double[] outputError = calculateOutputError(row.getDesiredOutput(), this.network.getOutput());
            updateTotalNetworkError(outputError);
            error += patternError(outputError);
        }
        return error;
    }

    /**
     * Not supported by this learning rule.
     *
     * <p>NOTE(review): the decompiler reported this as originally {@code protected}; it is kept
     * {@code public} so existing callers still compile.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.neuroph.core.learning.SupervisedLearning
    public void addToSquaredErrorSum(double[] dArr) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Runs one annealing epoch of {@code cycles} iterations over the data set.
     *
     * <p>On return the network holds the best weights found. {@code totalNetworkError} holds
     * their error, and {@code previousEpochError} holds the error reported by the previous epoch.
     *
     * @param dataSet training rows
     */
    @Override // org.neuroph.core.learning.SupervisedLearning, org.neuroph.core.learning.IterativeLearning
    public void doLearningEpoch(DataSet dataSet) {
        // totalNetworkError still holds the previous epoch's final error here. Capture it now,
        // because determineError() accumulates into it via updateTotalNetworkError().
        double lastEpochError = this.totalNetworkError;

        System.arraycopy(this.weights, 0, this.bestWeights, 0, this.weights.length);
        double bestError = determineError(dataSet);

        this.temperature = this.startTemperature;
        double coolingFactor = coolingFactor();
        for (int i = 0; i < this.cycles; i++) {
            randomize();
            double candidateError = determineError(dataSet);
            if (isStopped()) {
                // determineError() bails out early once stopped, so candidateError may cover only
                // part of the data set; discard the candidate instead of comparing it.
                System.arraycopy(this.bestWeights, 0, this.weights, 0, this.weights.length);
                NeuralNetworkCODEC.array2network(this.bestWeights, this.network);
                break;
            }
            if (candidateError < bestError) {
                System.arraycopy(this.weights, 0, this.bestWeights, 0, this.weights.length);
                bestError = candidateError;
            } else {
                // Reject the move: continue annealing from the best solution found so far.
                System.arraycopy(this.bestWeights, 0, this.weights, 0, this.weights.length);
            }
            NeuralNetworkCODEC.array2network(this.bestWeights, this.network);
            this.temperature *= coolingFactor;
        }

        this.previousEpochError = lastEpochError;
        this.totalNetworkError = bestError;
        if (hasReachedStopCondition()) {
            stopLearning();
        }
    }

    /**
     * Returns the per-cycle temperature multiplier.
     *
     * <p>The multiplier takes the temperature from {@code startTemperature} to
     * {@code stopTemperature} after {@code cycles - 1} applications. It returns 1 when there are
     * fewer than two cycles, to avoid dividing by zero.
     */
    private double coolingFactor() {
        if (this.cycles <= 1) {
            return 1.0d;
        }
        return Math.exp(Math.log(this.stopTemperature / this.startTemperature) / (this.cycles - 1));
    }

    /**
     * Returns the network being trained.
     *
     * @return the network passed to the constructor
     */
    public NeuralNetwork getNetwork() {
        return this.network;
    }

    /**
     * Perturbs every candidate weight and writes the result into the network.
     *
     * <p>Each weight is shifted by a value in (-0.5, 0.5], scaled by
     * {@code temperature / startTemperature}.
     */
    public void randomize() {
        for (int i = 0; i < this.weights.length; i++) {
            this.weights[i] += (0.5d - Math.random()) / this.startTemperature * this.temperature;
        }
        NeuralNetworkCODEC.array2network(this.weights, this.network);
    }

    /** Intentionally empty: weights are changed by {@link #randomize()}, not by error gradients. */
    @Override // org.neuroph.core.learning.SupervisedLearning
    protected void updateNetworkWeights(double[] dArr) {
    }

    /**
     * Adds the pattern error of one output-error vector to {@code totalNetworkError}.
     *
     * @param dArr output error vector of a single pattern
     */
    protected void updateTotalNetworkError(double[] dArr) {
        this.totalNetworkError += patternError(dArr);
    }

    /**
     * Computes the error of one pattern: (sum of squared output errors) / (2 * number of outputs).
     *
     * @param outputError output error vector of a single pattern
     * @return the pattern error, or 0 for an empty vector (instead of 0/0 = NaN)
     */
    private static double patternError(double[] outputError) {
        if (outputError.length == 0) {
            return 0.0d;
        }
        double sumOfSquares = 0.0d;
        for (double e : outputError) {
            sumOfSquares += e * e;
        }
        return sumOfSquares / (2.0d * outputError.length);
    }
}
